{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "T4",
      "authorship_tag": "ABX9TyO3+qbEnufQLsoi2VvofRJ8",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "ce1f77a753394dc5a25e5470fac18560": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_7bb6256651a142eabc61984dfe5d379f",
              "IPY_MODEL_12154fd312434260b7f6779a857e1a82",
              "IPY_MODEL_2e44353a59c4480a8e877d842ad16061"
            ],
            "layout": "IPY_MODEL_abdc9ab22ec049938855373effaf1504"
          }
        },
        "7bb6256651a142eabc61984dfe5d379f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_090b3eaf7d2548ee867fc7d9ddf67523",
            "placeholder": "​",
            "style": "IPY_MODEL_2f2dd26e18ca47dfae4ff33dbb869c0f",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "12154fd312434260b7f6779a857e1a82": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_e5f6717710074184b78c30f4668be2b5",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_222a8b16e19140269c44afffbca96865",
            "value": 2
          }
        },
        "2e44353a59c4480a8e877d842ad16061": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c3fd940f4dd34ce5b7a462e2bf6f1f71",
            "placeholder": "​",
            "style": "IPY_MODEL_6264e61a163b4a5dbdb854f7e2ff3056",
            "value": " 2/2 [00:09&lt;00:00,  4.59s/it]"
          }
        },
        "abdc9ab22ec049938855373effaf1504": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "090b3eaf7d2548ee867fc7d9ddf67523": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2f2dd26e18ca47dfae4ff33dbb869c0f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "e5f6717710074184b78c30f4668be2b5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "222a8b16e19140269c44afffbca96865": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "c3fd940f4dd34ce5b7a462e2bf6f1f71": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6264e61a163b4a5dbdb854f7e2ff3056": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Mistral/Supervised_fine_tuning_(SFT)_of_an_LLM_using_Hugging_Face_tooling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Supervised fine-tuning (SFT) of an LLM\n",
        "\n",
        "Recall that creating a ChatGPT at home involves 3 steps:\n",
        "\n",
        "1. pre-training a large language model (LLM) to predict the next token on internet-scale data, on clusters of thousands of GPUs. One calls the result a \"base model\"\n",
        "2. supervised fine-tuning (SFT) to turn the base model into a useful assistant\n",
        "3. human preference fine-tuning which increases the assistant's friendliness, helpfulness and safety.\n",
        "\n",
        "In this notebook, we're going to illustrate step 2. This involves supervised fine-tuning (SFT for short), also called instruction tuning.\n",
        "\n",
        "Supervised fine-tuning takes in a \"base model\" from step 1, i.e. a model that has been pre-trained on predicting the next token on internet text, and turns it into a \"chatbot\"/\"assistant\". This is done by fine-tuning the model on human instruction data, using the cross-entropy loss. This means that the model is still trained to predict the next token, although we now want the model to generate useful completions given an instruction like \"what are 10 things to do in London?\", \"How can I make pancakes?\" or \"Write me a poem about elephants\".\n",
        "\n",
        "To do this, one requires human annotators to collect useful completions, on which we can train the model. OpenAI for instance [hired human contractors for this](https://gizmodo.com/chatgpt-openai-ai-contractors-15-dollars-per-hour-1850415474), which were asked to generate useful completions given instructions, like \"In London, you can visit the Big Ben and (...)\". A nice collection of openly available SFT datasets can be found [here](https://huggingface.co/collections/HuggingFaceH4/awesome-sft-datasets-65788b571bf8e371c4e4241a).\n",
        "\n",
        "This way, the model becomes more useful: rather than simply predicting the next token (which might give undesirable outputs, like generating follow-up questions rather than answering the question), we now make it more likely that the model will output useful completions for any instruction we give it. We basically steer it in the direction of generating useful completions which a human could have written given any instruction.\n",
        "\n",
        "Notes:\n",
        "\n",
        "* the entire notebook is based on and can be seen as an annotated version of the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) developed by Hugging Face, and more specifically the [recipe](https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/sft/config_lora.yaml) used to train Zephyr-7b-beta. Huge kudos to the team for creating this!\n",
        "* this notebook applies to any decoder-only LLM available in the Transformers library. In this notebook, we are going to fine-tune the [Mistral-7B base model](https://huggingface.co/mistralai/Mistral-7B-v0.1), which is one of the best open-source large language models at the time of writing."
      ],
      "metadata": {
        "id": "_oPAJl2uDmil"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Required hardware\n",
        "\n",
        "The notebook is designed to be run on any NVIDIA GPU which has the [Ampere architecture](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)) or later with at least 24GB of RAM. This includes:\n",
        "\n",
        "* NVIDIA RTX 3090, 4090\n",
        "* NVIDIA A100, H100, H200\n",
        "\n",
        "and so on. Personally I'm running the notebook on an RTX 4090 with 24GB of RAM.\n",
        "\n",
        "The reason for an Ampere requirement is because we're going to use the [bfloat16 (bf16) format](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format), which is not supported on older architectures like Turing.\n",
        "\n",
        "But: a few tweaks can be made to train the model in float16 (fp16), which is supported by older GPUs like:\n",
        "\n",
        "* NVIDIA RTX 2080\n",
        "* NVIDIA Tesla T4\n",
        "* NVIDIA V100.\n",
        "\n",
        "Comments are added regarding where to swap bf16 with fp16.\n",
        "\n",
        "## Set-up environment\n",
        "\n",
        "Let's start by installing all the 🤗 goodies we need to do supervised fine-tuning. We're going to use\n",
        "\n",
        "* Transformers for the LLM which we're going to fine-tune\n",
        "* Datasets for loading a SFT dataset from the 🤗 hub, and preparing it for the model\n",
        "* BitsandBytes and PEFT for fine-tuning the model on consumer hardware, leveraging [Q-LoRa](https://huggingface.co/blog/4bit-transformers-bitsandbytes), a technique which drastically reduces the compute requirements for fine-tuning\n",
        "* TRL, a [library](https://huggingface.co/docs/trl/index) which includes useful Trainer classes for LLM fine-tuning."
      ],
      "metadata": {
        "id": "Bjdw05Rk5fYm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q transformers[torch] datasets"
      ],
      "metadata": {
        "id": "-9357hGFRqdi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q bitsandbytes trl peft"
      ],
      "metadata": {
        "id": "rxpq51gW_ySc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also install [Flash Attention](https://github.com/Dao-AILab/flash-attention), which speeds up the attention computations of the model."
      ],
      "metadata": {
        "id": "0RgDwMbXpC4L"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install flash-attn --no-build-isolation"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iNJvtR1wGxHm",
        "outputId": "545b75a3-932e-4195-f74a-a6f3447f67de"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: flash-attn in /usr/local/lib/python3.10/dist-packages (2.4.2)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from flash-attn) (2.1.0+cu121)\n",
            "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from flash-attn) (0.7.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from flash-attn) (23.2)\n",
            "Requirement already satisfied: ninja in /usr/local/lib/python3.10/dist-packages (from flash-attn) (1.11.1.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (3.13.1)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (4.5.0)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (3.2.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (3.1.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (2023.6.0)\n",
            "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch->flash-attn) (2.1.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->flash-attn) (2.1.3)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->flash-attn) (1.3.0)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load dataset\n",
        "\n",
        "Note: the alignment handbook supports mixing several datasets, each with a certain portion of training examples. However, the Zephyr recipe only includes a single dataset, which is the [UltraChat200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)."
      ],
      "metadata": {
        "id": "hJIqlyI15kj8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# based on config\n",
        "raw_datasets = load_dataset(\"HuggingFaceH4/ultrachat_200k\")"
      ],
      "metadata": {
        "id": "YhYvRDF25j59"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The dataset contains various splits, each with a certain number of rows. In our case, as we're going to do supervised fine-tuning (SFT), only the \"train_sft\" and \"test_sft\" splits are relevant for us."
      ],
      "metadata": {
        "id": "WFYvJUfODabt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import DatasetDict\n",
        "\n",
        "# remove this when done debugging\n",
        "indices = range(0,100)\n",
        "\n",
        "dataset_dict = {\"train\": raw_datasets[\"train_sft\"].select(indices),\n",
        "                \"test\": raw_datasets[\"test_sft\"].select(indices)}\n",
        "\n",
        "raw_datasets = DatasetDict(dataset_dict)\n",
        "raw_datasets"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jX1sIK6X6Opi",
        "outputId": "f0ab2b40-6789-44c4-ece2-bd84e0d5fa65"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['prompt', 'prompt_id', 'messages'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['prompt', 'prompt_id', 'messages'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's check one example. The important thing is that each example should contain a list of messages:"
      ],
      "metadata": {
        "id": "sByKH4hd8mUm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "example = raw_datasets[\"train\"][0]\n",
        "print(example.keys())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P2XwVpT17TAb",
        "outputId": "8d441fee-4a0b-4310-9a30-86b4439e0609"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "dict_keys(['prompt', 'prompt_id', 'messages'])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Each message is a dictionary containing 2 keys, namely:\n",
        "\n",
        "* \"role\": specifies who the creator of the message is (could be \"system\", \"assistant\" or \"user\" - the latter referring to a human).\n",
        "* \"content\": the actual content of the message.\n",
        "\n",
        "Let's print out the sequence of messages for this training example:"
      ],
      "metadata": {
        "id": "HaWnT4uBCjTy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "messages = example[\"messages\"]\n",
        "for message in messages:\n",
        "  role = message[\"role\"]\n",
        "  content = message[\"content\"]\n",
        "  print('{0:20}:  {1}'.format(role, content))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IQ1sMda27Zj6",
        "outputId": "04b5ab19-2910-4f1b-91cb-3aece687e49b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "user                :  These instructions apply to section-based themes (Responsive 6.0+, Retina 4.0+, Parallax 3.0+ Turbo 2.0+, Mobilia 5.0+). What theme version am I using?\n",
            "On your Collections pages & Featured Collections sections, you can easily show the secondary image of a product on hover by enabling one of the theme's built-in settings!\n",
            "Your Collection pages & Featured Collections sections will now display the secondary product image just by hovering over that product image thumbnail.\n",
            "Does this feature apply to all sections of the theme or just specific ones as listed in the text material?\n",
            "assistant           :  This feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.\n",
            "user                :  Can you guide me through the process of enabling the secondary image hover feature on my Collection pages and Featured Collections sections?\n",
            "assistant           :  Sure, here are the steps to enable the secondary image hover feature on your Collection pages and Featured Collections sections:\n",
            "\n",
            "1. Log in to your Shopify account and go to your Online Store.\n",
            "2. Click on Customize theme for the section-based theme you are using.\n",
            "3. Navigate to the Collection pages or Featured Collections section where you want to enable the secondary image hover feature.\n",
            "4. With the section open, click on the section's setting (gear) icon in the top-left corner.\n",
            "5. In the settings panel that appears, look for an option labeled 'Image display' or 'Image hover'.\n",
            "6. If available, select 'Show secondary image on hover'.\n",
            "7. Save the changes and preview the Collection/Featured Collection page to see the effect.\n",
            "\n",
            "If you're having trouble finding the setting, the best thing to do is refer to your theme's documentation, since the location and labeling of settings can vary between themes.\n",
            "user                :  Can you provide me with a link to the documentation for my theme?\n",
            "assistant           :  I don't have access to your store's theme information. However, you can usually find the documentation for your theme by going to the shopify theme store, finding your theme and clicking on the 'support' link located in the bottom right corner of the page. Alternatively, you can do a google search for the name of your theme followed by 'documentation' or 'user guide'.\n",
            "user                :  Can you confirm if this feature also works for the Quick Shop section of my theme?\n",
            "assistant           :  The secondary image hover feature may or may not work for your Quick Shop section, depending on the configuration of your theme. Some themes include this feature in the Quick Shop section by default, while others may require additional customization. To check if this feature is available for the Quick Shop section of your theme, follow these steps:\n",
            "\n",
            "1. Go to the Quick Shop section where you would like to enable the feature. 2. Click on the Quick Shop settings icon (gear icon) and look for 'Image display' or 'Image hover'. 3. If available, select 'Show secondary image on hover'. 4. Save the changes. If this option is not available in your Quick Shop section settings, you may need to reach out to your theme developer for assistance with customizing your Quick Shop section to include this feature.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "In this case, it looks like the instructions are about enabling certain features in Shopify. Interesting!"
      ],
      "metadata": {
        "id": "0DL8S3dkT2Av"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load tokenizer\n",
        "\n",
        "Next, we instantiate the tokenizer, which is required to prepare the text for the model. The model doesn't directly take strings as input, but rather `input_ids`, which represent integer indices in the vocabulary of a Transformer model. Refer to my [YouTube video](https://www.youtube.com/watch?v=IGu7ivuy1Ag&ab_channel=NielsRogge) if you want to know more about it.\n",
        "\n",
        "We also set some attributes which the tokenizer of a base model typically doesn't have set, such as:\n",
        "\n",
        "- the padding token ID. During pre-training, one doesn't need to pad since one just creates blocks of text to predict the next token, but during fine-tuning, we will need to pad the (instruction, completion) pairs in order to create batches of equal length.\n",
        "- the model max length: this is required in order to truncate sequences which are too long for the model. Here we decide to train on at most 2048 tokens.\n",
        "- the chat template. A [chat template](https://huggingface.co/blog/chat-templates) determines how each list of messages is turned into a tokenizable string, by adding special strings in between such as `<|user|>` to indicate a user message and `<|assistant|>` to indicate the chatbot's response. Here we define the default chat template, used by most chat models. See also the [docs](https://huggingface.co/docs/transformers/main/en/chat_templating)."
      ],
      "metadata": {
        "id": "RVRxCJQ76spF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "model_id = \"mistralai/Mistral-7B-v0.1\"\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "# set pad_token_id equal to the eos_token_id if not set\n",
        "if tokenizer.pad_token_id is None:\n",
        "  tokenizer.pad_token_id = tokenizer.eos_token_id\n",
        "\n",
        "# Set reasonable default for models without max length\n",
        "if tokenizer.model_max_length > 100_000:\n",
        "  tokenizer.model_max_length = 2048\n",
        "\n",
        "# Set chat template\n",
        "DEFAULT_CHAT_TEMPLATE = \"{% for message in messages %}\\n{% if message['role'] == 'user' %}\\n{{ '<|user|>\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'system' %}\\n{{ '<|system|>\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'assistant' %}\\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\\n{% endif %}\\n{% if loop.last and add_generation_prompt %}\\n{{ '<|assistant|>' }}\\n{% endif %}\\n{% endfor %}\"\n",
        "tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE"
      ],
      "metadata": {
        "id": "VvIfUqjK6ntu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Apply chat template\n",
        "\n",
        "Once we have equipped the tokenizer with the appropriate attributes, it's time to apply the chat template to each list of messages. Here we basically turn each list of (instruction, completion) messages into a tokenizable string for the model.\n",
        "\n",
        "Note that we specify `tokenize=False` here, since the `SFTTrainer` which we'll define later on will perform the tokenization internally. Here we only turn the list of messages into strings with the same format."
      ],
      "metadata": {
        "id": "3UNHQsTJ7O6I"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "import random\n",
        "from multiprocessing import cpu_count\n",
        "\n",
        "def apply_chat_template(example, tokenizer):\n",
        "    messages = example[\"messages\"]\n",
        "    # We add an empty system message if there is none\n",
        "    if messages[0][\"role\"] != \"system\":\n",
        "        messages.insert(0, {\"role\": \"system\", \"content\": \"\"})\n",
        "    example[\"text\"] = tokenizer.apply_chat_template(messages, tokenize=False)\n",
        "\n",
        "    return example\n",
        "\n",
        "column_names = list(raw_datasets[\"train\"].features)\n",
        "raw_datasets = raw_datasets.map(apply_chat_template,\n",
        "                                num_proc=cpu_count(),\n",
        "                                fn_kwargs={\"tokenizer\": tokenizer},\n",
        "                                remove_columns=column_names,\n",
        "                                desc=\"Applying chat template\",)\n",
        "\n",
        "# create the splits\n",
        "train_dataset = raw_datasets[\"train\"]\n",
        "eval_dataset = raw_datasets[\"test\"]\n",
        "\n",
        "for index in random.sample(range(len(raw_datasets[\"train\"])), 3):\n",
        "  print(f\"Sample {index} of the processed training set:\\n\\n{raw_datasets['train'][index]['text']}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "44kIpOXa7Ep4",
        "outputId": "2895c3ad-d3c6-481f-8f96-64481dfbb977"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sample 17 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "Is it possible to have a combination of drawers, cupboards, and wine racks in the same sideboard?: Hand Made any size, any combination or drawers, cupboards, shelving, glazed, solid doors, wine racks.\n",
            "Stunning Oak top shown here but can be made with pine, Reclaimed Timber, Driftwood, Rustic Plank and beech tops also.\n",
            "4 Door Oslo sideboard as shown from £815. Can be made any size with your combination of drawers, cupboards, wine racks, shelving, door and plinth styles available, choice of different worktops (shown here with 20mm bull nose oak top) and finished in any Farrow and Ball colour of your choice.\n",
            "n.b prices are correct as of 01-05/16 but may be subject to change – please contact us for a quote based on your requirements, sizes and needs.</s>\n",
            "<|assistant|>\n",
            "Yes, it is possible to have a combination of drawers, cupboards, and wine racks in the same sideboard. The Hand Made sideboard can be customized to include any combination of these features, as well as shelving, glazed or solid doors, and different types of worktops. Prices for custom orders may vary, so interested customers should contact the seller for a quote based on their specific needs and requirements.</s>\n",
            "<|user|>\n",
            "Can you provide me with the contact information of the seller of the Hand Made sideboard? I would like to request a quote based on my specific needs and requirements.</s>\n",
            "<|assistant|>\n",
            "I do not have access to specific information about sellers or their contact details. However, if you have seen this sideboard online or in a physical store, you may be able to find the seller's contact information on their website or by reaching out to the store directly. Alternatively, you could search for similar handcrafted furniture makers in your area or online and reach out to them for custom orders.</s>\n",
            "<|user|>\n",
            "Can you provide me with some examples of the different worktops available for the Hand Made sideboard?</s>\n",
            "<|assistant|>\n",
            "Certainly, here are some examples of the different worktops available for the Hand Made sideboard:\n",
            "\n",
            "1. Oak - this is a classic and timeless choice that is known for its durability and natural beauty.\n",
            "2. Pine - a softer wood that has a warm, rustic feel and is often used in traditional or country-style interiors.\n",
            "3. Reclaimed Timber - this is wood that has been salvaged from old buildings, furniture or other sources, and can add character and history to a piece.\n",
            "4. Driftwood - a unique and striking option that is made from wood that has been washed up on shore and then treated and smoothed for use in furniture.\n",
            "5. Rustic Plank - this is a style of wood where individual planks of wood are combined to create a surface that has a rough-hewn, textured appearance.\n",
            "6. Beech - a smooth and hard-wearing option that has a uniform grain and a light, pale color.\n",
            "\n",
            "These are just a few examples of the different worktops available, and the options may vary based on the seller's availability and preferences.</s>\n",
            "\n",
            "Sample 29 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "Design a series of posters to promote bicycle commuting.</s>\n",
            "<|assistant|>\n",
            "I do not have the ability to create visual posters, but I can provide you with some taglines and content ideas that could be used in promoting bicycle commuting:\n",
            "\n",
            "1. \"riding a bike to work saves money, the environment, and your health!\"\n",
            "2. \"get exercise, fresh air, and beat traffic by riding your bike to work.\"\n",
            "3. \"join the bike-to-work movement! help reduce carbon emissions and promote sustainable living, one pedal stroke at a time!\"\n",
            "4. \"who needs a gym when you can burn calories on your bike commute?\"\n",
            "5. \"experience the freedom and joy of cycling while reducing congestion on the roads.\"\n",
            "6. \"riding a bike is good for your body, mind, and wallet. Give it a try!\"\n",
            "7. \"less time stuck in traffic. More time enjoying the open road. Ride to work and feel alive!\"\n",
            "8. \"reduce your carbon footprint and get in shape by biking to work. It’s a win-win for you and the planet!\"\n",
            "9. \"bike to work and explore your city in a new way. Discover hidden treasures and enjoy the scenery.\"\n",
            "10. \"biking to work is not just a mode of transportation. It's a lifestyle that promotes health, wellness, and sustainability.\"</s>\n",
            "<|user|>\n",
            "Could you possibly provide me with more information on the benefits of bike commuting, like specific health benefits or the environmental impact?</s>\n",
            "<|assistant|>\n",
            "Certainly! Here are some additional benefits of bike commuting that you could use in your posters or other promotional materials:\n",
            "\n",
            "1. Health benefits: Biking is a great way to stay active, which can have numerous health benefits such as reducing the risk of obesity, heart disease, stroke, and type 2 diabetes. It can also improve mental health by reducing stress levels. 2. Environmental impact: Bike commuting is a sustainable transportation option that does not contribute to air pollution, noise pollution, or greenhouse gas emissions that are associated with motorized vehicles. By choosing to bike instead of drive, you can help to reduce your carbon footprint and protect the environment. 3. Economic benefits: Biking can be a more cost-effective mode of transportation. By avoiding the cost of gas, insurance, and car maintenance, you can save money in the long run. Additionally, some cities may offer incentives or tax credits to encourage bike commuting. 4. Improved productivity: Biking to work can help to increase alertness and productivity by providing a morning exercise routine that can help you feel more energized throughout the day. 5. Free parking: Unlike cars, bikes do not require parking spaces or fees, which means that you can save money and time by not having to search for parking. These are just a few of the many benefits of biking to work. By highlighting these benefits, you can help to encourage more people to give bike commuting a try.</s>\n",
            "<|user|>\n",
            "Could you provide me with some statistics to back up these points?</s>\n",
            "<|assistant|>\n",
            "Certainly! Here are some statistics that you could use to support the benefits of bike commuting in your posters:\n",
            "\n",
            "1. Health benefits:\n",
            "   - According to the World Health Organization, physical inactivity is the fourth leading risk factor for global mortality. Regular physical activity, such as cycling, can reduce the risk of non-communicable diseases, such as heart disease and stroke, by up to 30%. - Cycling for just 30 minutes per day can burn up to 500 calories. 2. Environmental impact:\n",
            "   - According to the European Cyclists' Federation, a person cycling 10 km to work and back every day can save over 1,500 kg of CO2 emissions per year, compared to traveling by car. - The Union of Concerned Scientists estimates that if Americans replaced just one car trip per day with a bike ride, it would save 2 billion gallons of gas and reduce carbon emissions by 22 billion pounds per year. 3. Economic benefits:\n",
            "   - According to the Bicycle Commuter Act, a tax credit of up to $240 per year is available for employees who bike to work. - A study from the University of California found that employers can save up to $7,000 per year for each employee who bikes to work instead of driving. 4. Improved productivity:\n",
            "   - A study from the University of Bristol found that employees who ride their bikes to work are less likely to experience a dip in concentration and energy mid-morning and mid-afternoon, leading to improved productivity. 5. Free parking:\n",
            "   - According to the National Association of City Transportation Officials, a single car parking spot can cost between $5,000 and $50,000 to build and maintain, whereas a bike parking facility can support dozens of bikes at a fraction of the cost. These statistics can help to make your posters more compelling and informative, and they can also make it easier for readers to understand the potential benefits of bike commuting.</s>\n",
            "<|user|>\n",
            "Could you also provide some tips for bike commuting beginners, like what gear they may need or safety precautions they should take?</s>\n",
            "<|assistant|>\n",
            "Sure! Here are some tips for beginners who want to start bike commuting:\n",
            "\n",
            "1. Get the right gear: Good gear can make your ride more comfortable and safer. You will need a helmet, a sturdy bike lock or secure parking space, and clothing appropriate for the weather conditions. If you are biking at night or in low-light conditions, be sure to wear reflective clothing or add reflective tape to your bike or helmet. 2. Plan your route: Research your route and choose one that is safe and convenient for your skill level. Consider bike-friendly roads or bike lanes, and avoid busy streets or heavy traffic areas. Many cities have online resources or maps of bike routes that you can use to plan your commute. 3. Practice safety precautions: Always obey traffic laws, including stop signs and traffic signals. Use hand signals to indicate when you are turning or stopping. Be visible, especially in low-light conditions, by using lights and reflective clothing. Always wear a helmet to protect your head. 4. Start slowly: If you are new to biking or haven't ridden in a while, start with short rides and gradually increase your distance and speed. Be patient with yourself and don't push yourself too hard. 5. Be prepared: Bring a small toolkit with you, including a spare tube, tire levers, and a pump. Make sure your bike is in good working condition before you set out on a ride. 6. Consider taking a bike safety class: Many cities offer bike safety classes that can help you learn the rules of the road and improve your cycling skills. Following these tips can help make your bike commute more enjoyable, safe, and easy.</s>\n",
            "\n",
            "Sample 61 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "What kind of collaborative work is expected of the Art Teacher at The School of Arts and Sciences?\n",
            "Generate according to: The School of Arts and Sciences, at 3208 Thomasville Rd, Tallahassee, FL is seeking applications for a Kindergarten-8th grade Art Teacher to teach in a multi-aged classroom. Applicants should be able to plan and execute rigorous, teacher created, hands-on thematic and project based lessons. They will also implement and support instruction and assessment practices for a diverse group of students that reflect the Florida Standards and the principles of their Charter. Also, they will work collaboratively with a diverse group of school colleagues, parents, community members and others to support and implement the goals of the school. Portfolio assessment experience. Salary and benefits commensurate with education and experience; eligible for benefits.\n",
            "If interested, please submit a resume and cover letter with SAST ART TEACHER in the subject line by May 11, 2018 to Eirin Lombardo, Principal at lombardoe@leonschools.net.</s>\n",
            "<|assistant|>\n",
            "As an Art Teacher at The School of Arts and Sciences, the successful applicant is expected to work collaboratively with a diverse group of school colleagues, parents, community members, and others. This includes working with other teachers and staff to integrate art into other subjects such as history, literature, and science. The applicant should also be able to participate in meetings and professional development activities that aim to improve instruction for students.\n",
            "\n",
            "The Art Teacher is expected to create and deliver rigorous, teacher-created, hands-on thematic and project-based lessons that reflect the Florida Standards and the principles of the school’s Charter. The successful applicant should be able to plan and execute lessons that are appropriate for a multi-aged classroom that includes Kindergarten to 8th grade students. They should be able to use various techniques, tools, and materials to develop students’ artistic abilities.\n",
            "\n",
            "The Art Teacher should also have experience in portfolio assessment and be able to implement and support instruction and assessment practices for a diverse group of students. They should be able to provide feedback to students on their artwork and help them reflect on their growth and development.\n",
            "\n",
            "In summary, the Art Teacher at The School of Arts and Sciences is expected to work collaboratively, plan and deliver rigorous, hands-on lessons, support instruction and assessment practices, work with a diverse group of students, and have experience in portfolio assessment.</s>\n",
            "<|user|>\n",
            "Can you give me some examples of how the Art Teacher at The School of Arts and Sciences can integrate art into other subjects?</s>\n",
            "<|assistant|>\n",
            "Certainly! Here are some examples of how the Art Teacher at The School of Arts and Sciences can integrate art into other subjects:\n",
            "\n",
            "1. History - The Art Teacher can teach students about a specific historical period and then have them create artwork that reflects the time period. For example, students could create Egyptian hieroglyphs or Greek pottery.\n",
            "\n",
            "2. Literature - The Art Teacher can have students create artwork based on a book they have read. For example, students could create a painting or sculpture based on characters from a book.\n",
            "\n",
            "3. Science - The Art Teacher can have students create artwork that reflects scientific concepts. For example, students could create a sculpture of an atom or create an ecosystem painting.\n",
            "\n",
            "4. Mathematics - The Art Teacher can teach students the principles of symmetry and have them create symmetrical artwork. For example, students could create a symmetrical mandala.\n",
            "\n",
            "5. Social Studies - The Art Teacher can have students create maps or globes to learn about geography. Students can also create posters or collages about different cultures.\n",
            "\n",
            "These are just a few examples, but there are many ways the Art Teacher can integrate art into other subjects to enhance students’ learning.</s>\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also specified `remove_columns` to the map function above, meaning that we are now left with only 1 column: \"text\".\n",
        "\n",
        "Hence the set-up is now very similar to pre-training: we will just train the model predict the next token, given the previous ones. In this case, the model will learn to generate completions given instructions.\n",
        "\n",
        "Hence, similar to pre-training, the labels will be created automatically based on the inputs (by shifting them one position to the right). The model is still trained using cross-entropy. This means that evaluation will mostly be done by checking perplexity/validation loss/model generations."
      ],
      "metadata": {
        "id": "Ro7VPkBLS8XG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "raw_datasets"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "L_tvjW-Y-uBT",
        "outputId": "631fdd9f-c4ac-4824-cf6f-7aded04ea5ca"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['text'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['text'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define model arguments\n",
        "\n",
        "Next, it's time to define the model arguments.\n",
        "\n",
        "Here, some explanation is required regarding ways to fine-tune model.\n",
        "\n",
        "### Full fine-tuning\n",
        "\n",
        "Typically, one performs \"full fine-tuning\": this means that we will simply update all the weights of the base model during fine-tuning. This is then typically done either in full precision (float32), or mixed precision (a combination of float32 and float16). However, with ever larger models like LLMs, this becomes infeasible.\n",
        "\n",
        "For reference, float32 means that each parameter of a model gets saved in 32 bits or 4 bytes. Hence, for a 7 billion parameter model like Mistral-7B, one requires 7 billion parameters \\* 4 bytes per parameter = 28 GB of GPU RAM, just to load the model. During training with an optimizer like AdamW, one not only requires memory for the model but also for the gradients and optimizer states, which roughly comes down to approximately 18 times the size of the model in gigabytes when training with mixed precision, in this case 7 * 18 = 126 GB of GPU RAM. And that's just for a 7B parameter model! See the guide for more info: https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one.\n",
        "\n",
        "### LoRa, a PEFT method\n",
        "\n",
        "Hence, some clever people at Microsoft have come up with a method called [LoRa](https://huggingface.co/docs/peft/conceptual_guides/lora) (low-rank adaptation). The idea here is that, rather than performing full fine-tuning, we are going to freeze the existing model and only add a few parameter weights to the model (called \"adapters\"), which we're going to train.\n",
        "\n",
        "LoRa is what we call a parameter-efficient fine-tuning (PEFT) method. It's a popular method for fine-tuning models in a parameter-efficient way, by only training a few adapters, keeping the existing model untouched. LoRa is available in the [PEFT library](https://huggingface.co/docs/peft/v0.7.1/en/index) by Hugging Face, which also supports various other PEFT methods (but LoRa is the most popular one at the time of writing).\n",
        "\n",
        "### QLoRa, an even more efficient method\n",
        "\n",
        "With regular LoRa, one would keep the base model in 32 or 16 bits in memory, and then train the parameter weights. However, there have been new methods developed to shrink the size of a model considerably, to 8 or 4 bits per parameter (we call this [\"quantization\"](https://huggingface.co/docs/transformers/main_classes/quantization)). Hence, if we apply LoRa to a quantized model (like a 4-bit model), then we call this QLoRa. We have a [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes) that tells you all about it. There are various quantization methods available, here we're going to use the [BitsandBytes](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig) integration.\n"
      ],
      "metadata": {
        "id": "7F9-BH4g9sr9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BitsAndBytesConfig\n",
        "import torch\n",
        "\n",
        "# specify how to quantize the model\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "            load_in_4bit=True,\n",
        "            bnb_4bit_quant_type=\"nf4\",\n",
        "            bnb_4bit_compute_dtype=\"torch.bfloat16\",\n",
        ")\n",
        "device_map = {\"\": torch.cuda.current_device()} if torch.cuda.is_available() else None\n",
        "\n",
        "model_kwargs = dict(\n",
        "    attn_implementation=\"flash_attention_2\", # set this to True if your GPU supports it (Flash Attention drastically speeds up model computations)\n",
        "    torch_dtype=\"auto\",\n",
        "    use_cache=False, # set to False as we're going to use gradient checkpointing\n",
        "    device_map=device_map,\n",
        "    quantization_config=quantization_config,\n",
        ")"
      ],
      "metadata": {
        "id": "XrSQuIyu8Rt1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define SFTTrainer\n",
        "\n",
        "Next, we define the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) available in the TRL library. This class inherits from the Trainer class available in the Transformers library, but is specifically optimized for supervised fine-tuning (instruction tuning). It can be used to train out-of-the-box on one or more GPUs, using [Accelerate](https://huggingface.co/docs/accelerate/index) as backend.\n",
        "\n",
        "Most notably, it supports [packing](https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-), where multiple short examples are packed in the same input sequence to increase training efficiency.\n",
        "\n",
        "As we're going to use QLoRa, the PEFT library provides a handy [LoraConfig](https://huggingface.co/docs/peft/v0.7.1/en/package_reference/lora#peft.LoraConfig) which defines on which layers of the base model to apply the adapters. One typically applies LoRa on the linear projection matrices of the attention layers of a Transformer. We then provide this configuration to the SFTTrainer class. The weights of the base model will be loaded as we specify the `model_id` (this requires some time).\n",
        "\n",
        "We also specify various hyperparameters regarding training, such as:\n",
        "* we're going to fine-tune for 1 epoch\n",
        "* the learning rate and its scheduler\n",
        "* we're going to use gradient checkpointing (yet another way to save memory during training)\n",
        "* and so on."
      ],
      "metadata": {
        "id": "u5ZvdLgSABbk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from trl import SFTTrainer\n",
        "from peft import LoraConfig\n",
        "from transformers import TrainingArguments\n",
        "\n",
        "# path where the Trainer will save its checkpoints and logs\n",
        "output_dir = 'data/zephyr-7b-sft-lora'\n",
        "\n",
        "# based on config\n",
        "training_args = TrainingArguments(\n",
        "    fp16=True, # specify bf16=True instead when training on GPUs that support bf16\n",
        "    do_eval=True,\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    gradient_accumulation_steps=128,\n",
        "    gradient_checkpointing=True,\n",
        "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
        "    learning_rate=2.0e-05,\n",
        "    log_level=\"info\",\n",
        "    logging_steps=5,\n",
        "    logging_strategy=\"steps\",\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    max_steps=-1,\n",
        "    num_train_epochs=1,\n",
        "    output_dir=output_dir,\n",
        "    overwrite_output_dir=True,\n",
        "    per_device_eval_batch_size=1, # originally set to 8\n",
        "    per_device_train_batch_size=1, # originally set to 8\n",
        "    # push_to_hub=True,\n",
        "    # hub_model_id=\"zephyr-7b-sft-lora\",\n",
        "    # hub_strategy=\"every_save\",\n",
        "    # report_to=\"tensorboard\",\n",
        "    save_strategy=\"no\",\n",
        "    save_total_limit=None,\n",
        "    seed=42,\n",
        ")\n",
        "\n",
        "# based on config\n",
        "peft_config = LoraConfig(\n",
        "        r=64,\n",
        "        lora_alpha=16,\n",
        "        lora_dropout=0.1,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
        ")\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "        model=model_id,\n",
        "        model_init_kwargs=model_kwargs,\n",
        "        args=training_args,\n",
        "        train_dataset=train_dataset,\n",
        "        eval_dataset=eval_dataset,\n",
        "        dataset_text_field=\"text\",\n",
        "        tokenizer=tokenizer,\n",
        "        packing=True,\n",
        "        peft_config=peft_config,\n",
        "        max_seq_length=tokenizer.model_max_length,\n",
        "    )"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 158,
          "referenced_widgets": [
            "ce1f77a753394dc5a25e5470fac18560",
            "7bb6256651a142eabc61984dfe5d379f",
            "12154fd312434260b7f6779a857e1a82",
            "2e44353a59c4480a8e877d842ad16061",
            "abdc9ab22ec049938855373effaf1504",
            "090b3eaf7d2548ee867fc7d9ddf67523",
            "2f2dd26e18ca47dfae4ff33dbb869c0f",
            "e5f6717710074184b78c30f4668be2b5",
            "222a8b16e19140269c44afffbca96865",
            "c3fd940f4dd34ce5b7a462e2bf6f1f71",
            "6264e61a163b4a5dbdb854f7e2ff3056"
          ]
        },
        "id": "W80YklLm_xAY",
        "outputId": "ced661c2-d638-4b4e-bc62-48ca240e2943"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:158: UserWarning: You passed a model_id to the SFTTrainer. This will automatically create an `AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ce1f77a753394dc5a25e5470fac18560"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:282: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
            "  warnings.warn(\n",
            "Using auto half precision backend\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Train!\n",
        "\n",
        "Finally, training is as simple as calling trainer.train()!"
      ],
      "metadata": {
        "id": "cGldALxQIwYu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "train_result = trainer.train()"
      ],
      "metadata": {
        "id": "HgEnI5KMIwyt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Saving the model\n",
        "\n",
        "Next, we save the Trainer's state. We also add the number of training samples to the logs."
      ],
      "metadata": {
        "id": "2xxjryHNBKD6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "metrics = train_result.metrics\n",
        "max_train_samples = training_args.max_train_samples if training_args.max_train_samples is not None else len(train_dataset)\n",
        "metrics[\"train_samples\"] = min(max_train_samples, len(train_dataset))\n",
        "trainer.log_metrics(\"train\", metrics)\n",
        "trainer.save_metrics(\"train\", metrics)\n",
        "trainer.save_state()"
      ],
      "metadata": {
        "id": "8Ai5jXhJBMsj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Inference\n",
        "\n",
        "Let's generate some new texts with our trained model.\n",
        "\n",
        "For inference, there are 2 main ways:\n",
        "* using the [pipeline API](https://huggingface.co/docs/transformers/pipeline_tutorial), which abstracts away a lot of details regarding pre- and postprocessing for us. [This model card](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta#intended-uses--limitations) for instance illustrates this.\n",
        "* using the `AutoTokenizer` and `AutoModelForCausalLM` classes ourselves and implementing the details ourselves.\n",
        "\n",
        "Let us do the latter, so that we understand what's going on.\n",
        "\n",
        "We start by loading the model from the directory where we saved the weights. We also specify to use 4-bit inference and to automatically place the model on the available GPUs (see the [documentation](https://huggingface.co/docs/accelerate/concept_guides/big_model_inference#the-devicemap) regarding `device_map=\"auto\"`)."
      ],
      "metadata": {
        "id": "-tCZxj1tBNAc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(output_dir)\n",
        "model = AutoModelForCausalLM.from_pretrained(output_dir, load_in_4bit=True, device_map=\"auto\")"
      ],
      "metadata": {
        "id": "yiRvmsSkyubH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, we prepare a list of messages for the model using the tokenizer's chat template. Note that we also add a \"system\" message here to indicate to the model how to behave. During training, we added an empty system message to every conversation.\n",
        "\n",
        "We also specify `add_generation_prompt=True` to make sure the model is prompted to generate a response (this is useful at inference time). We specify \"cuda\" to move the inputs to the GPU. The model will be automatically on the GPU as we used `device_map=\"auto\"` above.\n",
        "\n",
        "Next, we use the [generate()](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to autoregressively generate the next token IDs, one after the other. Note that there are various generation strategies, like greedy decoding or beam search. Refer to [this blog post](https://huggingface.co/blog/how-to-generate) for all details. Here we use sampling.\n",
        "\n",
        "Finally, we use the batch_decode method of the tokenizer to turn the generated token IDs back into strings."
      ],
      "metadata": {
        "id": "X7mfwoFnC5zW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating\n",
        "messages = [\n",
        "    {\n",
        "        \"role\": \"system\",\n",
        "        \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
        "    },\n",
        "    {\"role\": \"user\", \"content\": \"How many helicopters can a human eat in one sitting?\"},\n",
        "]\n",
        "\n",
        "# prepare the messages for the model\n",
        "input_ids = tokenizer.apply_chat_template(messages, truncation=True, add_generation_prompt=True, return_tensors=\"pt\").to(\"cuda\")\n",
        "\n",
        "# inference\n",
        "outputs = model.generate(\n",
        "        input_ids=input_ids,\n",
        "        max_new_tokens=256,\n",
        "        do_sample=True,\n",
        "        temperature=0.7,\n",
        "        top_k=50,\n",
        "        top_p=0.95\n",
        ")\n",
        "print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])"
      ],
      "metadata": {
        "id": "Hkacv5PvBOvE"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}